def config_time_ax(ax, sample_rate=22050):
    """Label the x-axis ticks of `ax` in seconds instead of sample indices.

    Tick positions are treated as sample indices (as produced when a wave is
    plotted with ``ax.plot(wave)``); each position is divided by
    `sample_rate` and shown with two decimal places.

    Args:
        ax: matplotlib Axes whose x-axis major formatter is replaced.
        sample_rate: samples per second used for the index -> seconds
            conversion. Defaults to 22050, the audio rate used by the
            surrounding plotting/playback code.
    """
    def format_func(value, tick_number):
        # `tick_number` is required by the FuncFormatter callback signature
        # but is not needed to format the label.
        return '%.2f' % (value / sample_rate)

    ax.xaxis.set_major_formatter(plt.FuncFormatter(format_func))
def draw_waves(waves, names, duration=0.1):
    """Plot the start of each waveform in its own vertically stacked subplot.

    Args:
        waves: sequence of sample arrays, presumably 1-D, sampled at 22050 Hz
            (the rate `config_time_ax` assumes by default when labelling the
            time axis).
        names: subplot titles, one per entry of `waves`.
        duration: seconds of audio to plot from the beginning of each wave
            (default 0.1, i.e. the first 2205 samples).

    Raises:
        ValueError: if `waves` and `names` differ in length, or are empty.
    """
    n = len(waves)
    if n != len(names):
        raise ValueError('got %d waves but %d names' % (n, len(names)))
    if n == 0:
        raise ValueError('need at least one wave to draw')

    num_samples = int(round(22050 * duration))

    # squeeze=False keeps `axes` a 2-D (n, 1) array even when n == 1; a plain
    # subplots(1) would return a bare Axes that cannot be zipped over.
    fig, axes = plt.subplots(n, dpi=100, figsize=(8, 6), squeeze=False)
    for i, (ax, wave, name) in enumerate(zip(axes[:, 0], waves, names)):
        config_time_ax(ax)
        ax.set_title(name)
        if i == n - 1:
            # Only the bottom subplot carries the shared x-axis label.
            ax.set_xlabel('Time (seconds)')
        ax.plot(wave[:num_samples])
    plt.tight_layout()
    plt.show()
    # Close this specific figure rather than whatever happens to be current.
    plt.close(fig)
# For each dataset, conv layer and channel: play and plot the tanh and
# sigmoid intermediate outputs saved for that (layer, channel), alongside the
# dataset's final generated waveform.
conv_ids = [0, 9, 19, 29]
channel_ids = [0, 63, 127]

# (section title, directory of intermediate outputs, final-output .npy file)
datasets = [
    ('Maestro', 'maestro-intermediate-examples', 'long-maestro-examples/0.npy'),
    ('Nsynth', 'nsynth-intermediate-examples', 'long-nsynth-examples/0.npy'),
]

out_names = ['Tanh', 'Sigmoid', 'Final output']
out_types = ['tanh', 'sigmoid', 'final']

for dir_name, in_dir, final_path in datasets:
    # The final output does not depend on layer/channel: load it once per dataset.
    wave_out = np.load(final_path)
    display(HTML('<h1>%s</h1>' % dir_name))
    for conv_id in conv_ids:
        display(HTML('<h2>Conv layer %d</h2>' % conv_id))
        for channel_id in channel_ids:
            display(HTML('<h3>Channel %d</h3>' % channel_id))
            # Keep each loaded wave so the plot below reuses it instead of
            # reading the same file from disk a second time.
            waves = []
            for out_name, out_type in zip(out_names, out_types):
                display(HTML('<h4>%s</h4>' % out_name))
                if out_type == 'final':
                    wave = wave_out
                else:
                    path = os.path.join(
                        in_dir,
                        'conv-%d-channel-%d-%s.npy' % (conv_id, channel_id, out_type))
                    wave = np.load(path)
                display(Audio(data=wave, rate=22050))
                waves.append(wave)
            draw_waves(waves, out_names)
            display(HTML('<hr style="border: 1px solid grey;">'))